Java并发编程笔记3-ThreadLocal 和InheritableThreadLocal 原理

预览(本文逻辑)

这篇文章的逻辑很简单:

  • 首先是通过之前文章我们知道锁不仅很重而且使用门槛也高,所以能不用锁我们就想不用,所以我们想实现一个功能,就是让共享变量在每个线程访问的时候访问的是当前线程内部的一个变量(这样说可能不太准确,但意思是这个),有了这个需求,就有了ThreadLocal类。
  • 然后我们深入的分析了一下ThreadLocal的原理,知道ThreadLocal是怎样实现这样的功能。
  • 实现这个功能以后,又有一个需求出来了,就是如果子线程想访问父线程的ThreadLocal中的变量怎么办,正常这个是不能访问的,也是不应该访问的,但就是有这样奇奇怪怪的需求,那这个需求能实现吗?所以就有了InheritableThreadLocal 类。
  • 最后我们再分析了 InheritableThreadLocal 怎么让子线程访问父线程的ThreadLocal变量

ThreadLocal 的实现原理

通过之前对并发编程的讲解,我们都知道了,对共享变量需要加锁,如下图所示:
Alt text

但这比较烦,不太懂锁怎么办?如果有一个方式创建一个变量,每个线程对它访问的时候,访问的是自己内部的变量那就好了。ThreadLocal 就可以干这件事!
ThreadLocal 是在 JDK 包里面提供的,它提供了线程本地变量,也就是如果你创建了一个 ThreadLocal 变量,那么访问这个变量的每个线程都会有这个变量的一个本地拷贝,多个线程操作这个变量的时候,实际是操作的自己本地内存里面的变量,从而避免了线程安全问题,创建一个 ThreadLocal 变量后每个线程会拷贝一个变量到自己本地内存,如下图:
Alt text

ThreadLocal 简单使用

本节来看下 ThreadLocal 如何使用,从而加深理解,本例子开启了两个线程,每个线程内部设置了本地变量的值,然后调用 print 函数打印当前本地变量的值,如果打印后调用了本地变量的 remove 方法则会删除本地内存中的该变量,代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
public class ThreadLocalTest {

//(1)打印函数
static void print(String str){
//1.1 打印当前线程本地内存中localVariable变量的值
System.out.println(str + ":" +localVariable.get());
//1.2 清除当前线程本地内存中localVariable变量
//localVariable.remove();
}
//(2) 创建ThreadLocal变量
static ThreadLocal<String> localVariable = new ThreadLocal<>();
public static void main(String[] args) {

//(3) 创建线程one
Thread threadOne = new Thread(new Runnable() {
public void run() {
//3.1 设置线程one中本地变量localVariable的值
localVariable.set("threadOne local variable");
//3.2 调用打印函数
print("threadOne");
//3.3打印本地变量值
System.out.println("threadOne remove after" + ":" +localVariable.get());

}
});
//(4) 创建线程two
Thread threadTwo = new Thread(new Runnable() {
public void run() {
//4.1 设置线程one中本地变量localVariable的值
localVariable.set("threadTwo local variable");
//4.2 调用打印函数
print("threadTwo");
//4.3打印本地变量值
System.out.println("threadTwo remove after" + ":" +localVariable.get());

}
});
//(5)启动线程
threadOne.start();
threadTwo.start();
}
1
2
3
4
5
6
运行结果:

threadOne:threadOne local variable
threadTwo:threadTwo local variable
threadOne remove after:threadOne local variable
threadTwo remove after:threadTwo local variable
  • 代码(2)创建了一个 ThreadLocal 变量;
  • 代码(3)、(4)分别创建了线程 one 和 two;
  • 代码(5)启动了两个线程;
  • 线程 one 中代码 3.1 通过 set 方法设置了 localVariable 的值,这个设置的其实是线程 one 本地内存中的一个拷贝,这个拷贝线程 two 是访问不了的。然后代码 3.2 调用了 print 函数,代码 1.1 通过 get 函数获取了当前线程(线程 one)本地内存中 localVariable 的值;

线程 two 执行类似线程 one。
解开代码 1.2 的注释后,再次运行,运行结果为:

1
2
3
4
threadOne:threadOne local variable
threadOne remove after:null
threadTwo:threadTwo local variable
threadTwo remove after:null

ThreadLocal 实现原理

首先看下 ThreadLocal 相关的类的类图结构。

Alt text

由这张图,我们可以看出,Thread类中存在两个变量,threadLocals 和 inheritableThreadLocals ,这两个变量类型都是ThreadLocalMap ,这个类实质是定制化的Hashmap。默认每个线程中这个两个变量都为 null,只有当前线程第一次调用了 ThreadLocal 的 set 或者 get 方法时候才会进行创建。

在这我们也能看出来,实际上ThreadLocal 并不存储变量,变量是在线程Thread的threadLocals 中的,也就是我们说的拷贝一份到线程的内存空间。ThreadLocal 就是一个工具壳,它通过 set 方法把 value 值放入调用线程的 threadLocals 里面存放起来,当调用线程调用它的 get 方法时候再从当前线程的 threadLocals变量里面拿出来使用。

如果调用线程一直不终止,那么这个本地变量会一直存放到调用线程的 threadLocals 变量里面,所以当不需要使用本地变量时候可以通过调用 ThreadLocal 变量的 remove 方法,从当前线程的 threadLocals 里面删除该本地变量。另外 Thread 里面的 threadLocals 为何设计为 map 结构呢?很明显是因为每个线程里面可以关联多个 ThreadLocal 变量。

下面简单分析下 ThreadLocal 的 set,get,remove 方法的实现逻辑:

  • void set(T value)
1
2
3
4
5
6
7
8
9
10
11
public void set(T value) {
//(1)获取当前线程
Thread t = Thread.currentThread();
//(2)当前线程作为key,去查找对应的线程变量,找到则设置
ThreadLocalMap map = getMap(t);
if (map != null)
map.set(this, value);
else
//(3)第一次调用则创建当前线程对应的HashMap
createMap(t, value);
}

如上代码(1)首先获取调用线程,然后使用当前线程作为参数调用了 getMap(t) 方法,getMap(Thread t) 代码如下:

1
2
3
ThreadLocalMap getMap(Thread t) {
return t.threadLocals;
}

可知 getMap(t) 所做的就是获取线程自己的变量 threadLocals,threadlocal 变量是绑定到了线程的成员变量里面。

如果 getMap(t) 返回不为空,则把 value 值设置进入到 threadLocals,也就是把当前变量值放入了当前线程的内存变量 threadLocals,threadLocals 是个 HashMap 结构,其中 key 就是当前 ThreadLocal 的实例对象引用,value 是通过 set 方法传递的值。

如果 getMap(t) 返回空那说明是第一次调用 set 方法,则创建当前线程的 threadLocals 变量,下面看 createMap(t, value) 里面做了啥呢?

1
2
3
 void createMap(Thread t, T firstValue) {
t.threadLocals = new ThreadLocalMap(this, firstValue);
}

可知实际就是创建当前线程的 threadLocals 变量。

  • T get()
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
public T get() {
//(4) 获取当前线程
Thread t = Thread.currentThread();
//(5)获取当前线程的threadLocals变量
ThreadLocalMap map = getMap(t);
//(6)如果threadLocals不为null,则返回对应本地变量值
if (map != null) {
ThreadLocalMap.Entry e = map.getEntry(this);
if (e != null) {
@SuppressWarnings("unchecked")
T result = (T)e.value;
return result;
}
}
//(7)threadLocals为空则初始化当前线程的threadLocals成员变量
return setInitialValue();
}

如上代码(4)首先获取当前线程实例,如果当前线程的 threadLocals 变量不为 null 则直接返回当前线程绑定的本地变量。否者执行代码(7)进行初始化,setInitialValue() 的代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
private T setInitialValue() {
//(8)初始化为null
T value = initialValue();
Thread t = Thread.currentThread();
ThreadLocalMap map = getMap(t);
//(9)如果当前线程的threadLocals变量不为空
if (map != null)
map.set(this, value);
else
//(10)如果当前线程的threadLocals变量为空
createMap(t, value);
return value;
}
1
2
3
protected T initialValue() {
return null;
}

如上代码如果当前线程的 threadLocals 变量不为空,则设置当前线程的本地变量值为 null,否者调用 createMap 创建当前线程的 createMap 变量。

  • void remove()
1
2
3
4
5
public void remove() {
ThreadLocalMap m = getMap(Thread.currentThread());
if (m != null)
m.remove(this);
}

如上代码,如果当前线程的 threadLocals 变量不为空,则删除当前线程中指定 ThreadLocal 实例的本地变量。

注:每个线程内部都有一个名字为 threadLocals 的成员变量,该变量类型为 HashMap,其中 key 为我们定义的 ThreadLocal 变量的 this 引用,value 则为我们 set 时候的值,每个线程的本地变量是存到线程自己的内存变量 threadLocals 里面的,如果当前线程一直不消失那么这些本地变量会一直存到,所以可能会造成内存泄露,所以使用完毕后要记得调用 ThreadLocal 的 remove 方法删除对应线程的 threadLocals 中的本地变量。

子线程中获取不到父线程中设置的 ThreadLocal 变量的值

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
public class ThreadLocalTest3 {

//(1) 创建线程变量
public static ThreadLocal<String> threadLocal = new ThreadLocal<String>();

public static void main(String[] args) {

//(2) 设置线程变量
threadLocal.set("hello world");
//(3) 启动子线程
Thread thread = new Thread(new Runnable() {
public void run() {
//(4)子线程输出线程变量的值
System.out.println("thread:" + threadLocal.get());

}
});
thread.start();

//(5)主线程输出线程变量值
System.out.println("main:" + threadLocal.get());

}
}

结果为:

1
2
main:hello world
thread:null

首先这个问题是应该的吗?当然是!我们刚刚说的ThreadLocal 是线程内部变量,子线程当然是不该获取到父线程的变量,但是我们还是想获取,比如有的时候有一些特殊需求咋办呢?能获取到吗?当然能了。

还有一个需要思考的是,这里怎么实现子线程获取不到父线程的ThreadLocal 的,我们可以看到在get方法部分首先获取的是当前线程,是通过这个来控制的。

InheritableThreadLocal 原理

为了解决上节的问题 InheritableThreadLocal 应运而生,InheritableThreadLocal 继承自 ThreadLocal,提供了一个特性,就是子线程可以访问到父线程中设置的本地变量。

看下它内部的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
public class InheritableThreadLocal<T> extends ThreadLocal<T> {

//(1)
protected T childValue(T parentValue) {
return parentValue;
}
//(2)
ThreadLocalMap getMap(Thread t) {
return t.inheritableThreadLocals;
}
//(3)
void createMap(Thread t, T firstValue) {
t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
}
}

如上代码可知 InheritableThreadLocal 继承了 ThreadLocal,并重写了三个方法。

  • 代码(3)可知 InheritableThreadLocal 重写了 createMap 方法,那么可知现在当第一次调用 set 方法时候创建的是当前线程的 inheritableThreadLocals 变量的实例而不再是 threadLocals。
  • 代码(2)可知当调用 get 方法获取当前线程的内部 map 变量时候,获取的是 inheritableThreadLocals 而不再是 threadLocals。

综上可知在 InheritableThreadLocal 的世界里,线程中的变量 inheritableThreadLocals 替代了 threadLocals。

  • 下面我们看下重写的代码(1)是何时被执行,以及如何实现的子线程可以访问父线程本地变量的。这个要从 Thread 创建的代码看起,Thread 的默认构造函数及 Thread.java 类的构造函数如下:
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
   public Thread(Runnable target) {
init(null, target, "Thread-" + nextThreadNum(), 0);
}
private void init(ThreadGroup g, Runnable target, String name,
long stackSize, AccessControlContext acc) {
...
//(4)获取当前线程
Thread parent = currentThread();
...
//(5)如果父线程的inheritableThreadLocals变量不为null
if (parent.inheritableThreadLocals != null)
//(6)设置子线程中的inheritableThreadLocals变量
this.inheritableThreadLocals =
ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
this.stackSize = stackSize;
tid = nextThreadID();
}

创建线程时候在构造函数里面会调用 init 方法,前面讲到了 inheritableThreadLocal 类 get,set 方法操作的是变量 inheritableThreadLocals,所以这里 inheritableThreadLocal 变量就不为 null,所以会执行代码(6),下面看下 createInheritedMap 代码:

1
2
3
static ThreadLocalMap createInheritedMap(ThreadLocalMap parentMap) {
return new ThreadLocalMap(parentMap);
}

可知 createInheritedMap 内部使用父线程的 inheritableThreadLocals 变量作为构造函数创建了一个新的 ThreadLocalMap 变量。然后赋值给了子线程的 inheritableThreadLocals 变量,那么下面看看 ThreadLocalMap 的构造函数里面做了什么:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private ThreadLocalMap(ThreadLocalMap parentMap) {
Entry[] parentTable = parentMap.table;
int len = parentTable.length;
setThreshold(len);
table = new Entry[len];

for (int j = 0; j < len; j++) {
Entry e = parentTable[j];
if (e != null) {
@SuppressWarnings("unchecked")
ThreadLocal<Object> key = (ThreadLocal<Object>) e.get();
if (key != null) {
//(7)调用重写的方法
Object value = key.childValue(e.value);//返回e.value
Entry c = new Entry(key, value);
int h = key.threadLocalHashCode & (len - 1);
while (table[h] != null)
h = nextIndex(h, len);
table[h] = c;
size++;
}
}
}
}

如上代码所做的事情就是把父线程的 inheritableThreadLocals 成员变量的值复制到新的 ThreadLocalMap 对象,其中代码(7)InheritableThreadLocal 类重写的代码(1)也映入眼帘了。

总结:InheritableThreadLocal 类通过重写代码(2)和(3)让本地变量保存到了具体线程的 inheritableThreadLocals 变量里面,线程通过 InheritableThreadLocal 类实例的 set 或者 get 方法设置变量时候就会创建当前线程的 inheritableThreadLocals 变量。当父线程创建子线程时候,构造函数里面会把父线程中 inheritableThreadLocals 变量里面的本地变量拷贝一份复制到子线程的 inheritableThreadLocals 变量里面。

把上节代码(1)修改为:

1
2
//(1) 创建线程变量
public static ThreadLocal<String> threadLocal = new InheritableThreadLocal<String>();

运行结果为:

1
2
thread:hello world
main:hello world

可知现在可以从子线程中正常的获取到线程变量值了。

那么什么情况下需要子线程可以获取到父线程的 threadlocal 变量呢,情况还是蛮多的,比如存放用户登录信息的 threadlocal 变量,很有可能子线程中也需要使用用户登录信息,再比如一些中间件需要用统一的追踪 ID 把整个调用链路记录下来的情景。